{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text classifier (data processing with torchtext)\n",
    "References: [A detailed introduction to Torchtext, Part 1](https://zhuanlan.zhihu.com/p/37223078), [github](https://github.com/keitakurita/practical-torchtext/blob/master/Lesson%201%20intro%20to%20torchtext%20with%20text%20classification.ipynb), [A hands-on guide to processing text data with torchtext](https://cloud.tencent.com/developer/article/1168890)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import spacy\n",
    "import torch\n",
    "import torchtext\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_root = os.path.join('../..', 'data','text')         # root directory of the data\n",
    "data_list = ['train', 'val', 'test']                     # names of the train/val/test splits\n",
    "train_csv, val_csv, test_csv = map(lambda x:os.path.join(data_root, x+'.csv'), data_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
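    "Each row holds a comment (`comment_text`) and its six labels: toxic, severe_toxic, obscene, threat, insult and identity_hate. A comment can carry several labels at once, and the dataset can be used to train a **text classifier**."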
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment_text</th>\n",
       "      <th>toxic</th>\n",
       "      <th>severe_toxic</th>\n",
       "      <th>obscene</th>\n",
       "      <th>threat</th>\n",
       "      <th>insult</th>\n",
       "      <th>identity_hate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000997932d777bf</td>\n",
       "      <td>Explanation\\nWhy the edits made under my usern...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>000103f0d9cfb60f</td>\n",
       "      <td>D'aww! He matches this background colour I'm s...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 id                                       comment_text  toxic  \\\n",
       "0  0000997932d777bf  Explanation\\nWhy the edits made under my usern...      0   \n",
       "1  000103f0d9cfb60f  D'aww! He matches this background colour I'm s...      0   \n",
       "\n",
       "   severe_toxic  obscene  threat  insult  identity_hate  \n",
       "0             0        0       0       0              0  \n",
       "1             0        0       0       0              0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(train_csv).head(2)                          # show the training data\n",
    "# pd.read_csv(val_csv).head(2)                          # the validation set has the same columns as the training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment_text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>00001cee341fdb12</td>\n",
       "      <td>Yo bitch Ja Rule is more succesful then you'll...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0000247867823ef7</td>\n",
       "      <td>== From RfC == \\n\\n The title is fine as it is...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 id                                       comment_text\n",
       "0  00001cee341fdb12  Yo bitch Ja Rule is more succesful then you'll...\n",
       "1  0000247867823ef7  == From RfC == \\n\\n The title is fine as it is..."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(test_csv).head(2)                           # show the test data (no label columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Declare Field objects to define the preprocessing pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Define the tokenizer function\n",
    "Clean the raw text, then split it into tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
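    "# Only the tokenizer is used, so a blank English pipeline is enough;\n",
    "# the old shortcut spacy.load('en') was removed in spaCy 3.\n",
    "NLP = spacy.blank('en')\n",
    "MAX_CHARS = 20000                    # longer comments are truncated\n",
    "def tokenizer(comment):         \n",
    "    \"\"\"\n",
    "        func: clean and preprocess the text, then return its tokens (words and numbers)\n",
    "        comment: the text to process (str)\n",
    "    \"\"\"\n",
    "    comment = re.sub(r'[\\*\\\"“”\\n\\.\\+\\-\\/\\=\\(\\)\\!;\\\\]', \" \", str(comment))   # replace unwanted characters with spaces\n",
    "    comment = re.sub(r'\\s+', ' ', comment)               # collapse runs of whitespace into one space\n",
    "    comment = re.sub(r'\\!+', '!', comment)               # collapse repeated '!' (a no-op here: '!' is already replaced above)\n",
    "    comment = re.sub(r'\\,+', ',', comment)               # same for ','\n",
    "    comment = re.sub(r'\\?+', '?', comment)               # same for '?'\n",
    "    \n",
    "    if len(comment) > MAX_CHARS:                         # truncate overly long text\n",
    "        comment = comment[:MAX_CHARS] \n",
    "    # return only words and numbers, dropping whitespace and punctuation tokens\n",
    "    return [token.text for token in NLP.tokenizer(comment) if not token.is_space and not token.is_punct]"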
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['abc', 'hahah', 'abc', 'def', 'class', 'printf', 'sub', 'fdf', '233']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test the tokenizer\n",
    "test_doc = 'abc .... hahah,,, ..?? abc def class “ ” printf sub!! @  :  (\\) fdf / 233((()))'\n",
    "test_token = tokenizer(test_doc)\n",
    "test_token"
   ]
  },
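  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Without spaCy, the same idea (clean, truncate, keep only word-like tokens) can be approximated with the standard library. The sketch below is a stand-in under stated assumptions, not the tokenizer used in this notebook: `simple_tokenizer` is a hypothetical name, and the regex `\\w+` keeps runs of letters, digits and underscores, so it splits contractions such as `I'm` differently from spaCy.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "MAX_CHARS = 20000\n",
    "\n",
    "def simple_tokenizer(comment):\n",
    "    # Hypothetical stdlib-only variant of tokenizer(): truncate first,\n",
    "    # then keep only runs of word characters (letters, digits, underscore).\n",
    "    text = str(comment)[:MAX_CHARS]\n",
    "    return re.findall(r'\\w+', text)\n",
    "\n",
    "print(simple_tokenizer('abc .... hahah,,, ..?? abc def sub!! fdf / 233((()))'))\n",
    "# ['abc', 'hahah', 'abc', 'def', 'sub', 'fdf', '233']\n",
    "```\n",
    "\n",
    "Unlike the spaCy version, punctuation never becomes a token here, so no `is_punct` filter is needed."
   ]
  },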
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
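    "### 2.2 Define the Field objects\n",
    "A Field declares the preprocessing pipeline for one column of the data.\n",
    "\n",
    "For text that has to be tokenized and mapped to indices, set `sequential=True` and `use_vocab=True`; if a column already holds numbers (such as the labels here), set both flags to `False`. `batch_first=True` puts the batch dimension first and the sequence dimension second, i.e. tensors of shape `(batch, seq_len)`.\n",
    "\n",
    "For more parameters see the [official docstring](https://github.com/pytorch/text/blob/c839a7934930819be7e240ea972e4d600966afdc/torchtext/data/field.py#L61) or the [blog post: a guide to Torchtext (focused on NMT)](http://www.cnblogs.com/helloeboy/p/9882467.html)"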
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext.data import Field\n",
    "\n",
    "TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, use_vocab=True, batch_first=True)   # comment text: tokenize, lowercase, build a vocabulary\n",
    "LABEL = Field(sequential=False, use_vocab=False, batch_first=True)                                  # labels: already 0/1 integers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
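    "## 3. Build the datasets\n",
    "### 3.1 Preprocess the data\n",
    "The data is processed according to the Field declarations.\n",
    "The training CSV has 8 columns. The `fields` argument is a list of `(name, field)` pairs that must follow the column order of the file; unused columns are declared with `None`, e.g. `('id', None)`, and every other column is processed by its Field."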
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 111 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from torchtext.data import TabularDataset\n",
    "\n",
    "train_val_fields = [('id', None),       ('comment_text', TEXT),\n",
    "                    ('toxic', LABEL),   ('severe_toxic', LABEL),\n",
    "                    ('obscene', LABEL), ('threat', LABEL),\n",
    "                    ('insult', LABEL),   ('identity_hate', LABEL)]\n",
    "train_dataset, val_dataset = TabularDataset.splits(      # use splits to load several datasets at once\n",
    "    path = data_root,                                    # directory containing the files\n",
    "    train = 'train.csv', validation = 'val.csv',\n",
    "    format = 'csv',                                      # file format\n",
    "    skip_header = True,                                  # skip the CSV header row\n",
    "    fields = train_val_fields)"
   ]
  },
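  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The positional `(name, field)` pairing can be pictured with Python's built-in `csv` module. This is a hypothetical sketch of the idea, not torchtext's implementation: each row is zipped with the field list, columns whose field is `None` are dropped, and the header row is skipped as `skip_header=True` does. The rows are made up, and `'TEXT'`/`'LABEL'` are placeholder strings standing in for the Field objects.\n",
    "\n",
    "```python\n",
    "import csv\n",
    "\n",
    "# A tiny made-up CSV with the same 8 columns as train.csv.\n",
    "lines = [\n",
    "    'id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate',\n",
    "    'a1,hello there,0,0,0,0,0,0',\n",
    "    'b2,sample comment,1,0,0,0,1,0',\n",
    "]\n",
    "# Column names paired positionally with a (placeholder) field; None = ignore.\n",
    "fields = [('id', None), ('comment_text', 'TEXT'),\n",
    "          ('toxic', 'LABEL'), ('severe_toxic', 'LABEL'),\n",
    "          ('obscene', 'LABEL'), ('threat', 'LABEL'),\n",
    "          ('insult', 'LABEL'), ('identity_hate', 'LABEL')]\n",
    "\n",
    "reader = csv.reader(lines)\n",
    "next(reader)  # skip the header row, like skip_header=True\n",
    "examples = []\n",
    "for row in reader:\n",
    "    # Keep only columns that have a field, keyed by the declared name.\n",
    "    # Values are still raw strings at this point in the sketch.\n",
    "    examples.append({name: value for (name, field), value in zip(fields, row)\n",
    "                     if field is not None})\n",
    "\n",
    "print(examples[1]['comment_text'], examples[1]['toxic'])  # sample comment 1\n",
    "```\n",
    "\n",
    "Because matching is purely by position, swapping two columns in the CSV silently swaps their meaning unless the field list is updated as well."
   ]
  },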
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
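    "The test data has no labels, so it is loaded differently from the training and validation sets: only the `id` and `comment_text` columns are declared."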
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 46.1 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "test_fields = [('id', None), ('comment_text', TEXT)]\n",
    "test_dataset = TabularDataset(                           # load a single dataset\n",
    "    path = os.path.join(data_root, 'test.csv'),\n",
    "    format = 'csv',\n",
    "    skip_header = True,\n",
    "    fields = test_fields)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
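    "Like a list, a dataset can be **indexed** to get a single sample, e.g. `train_dataset[0]`. Each sample is an `Example` object whose fields are stored as attributes (visible through its `__dict__`), so a column is accessed as `sample.comment_text`, where `comment_text` is the column **header** in the CSV."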
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torchtext.data.example.Example at 0x20b4861b9e8>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = train_dataset[0]             # take the first sample\n",
    "sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample.__dict__.keys()                # the sample's fields are stored in its attribute dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "1. text: ['explanation', 'why', 'the', 'edits', 'made', 'under', 'my']\n",
       "2. label: 0\n"
      ]
     }
    ],
    "source": [
     "print('1. text:', sample.comment_text[:7])       # fields are accessed as attributes\n",
     "print('2. label:', sample.toxic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
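    "### 3.2 Build the vocabulary from the training set\n",
    "The `stoi` attribute is a **dict** (a `collections.defaultdict`) of the form `{word: index, ...}`, and `itos` is a **list** of the form `[word, ...]`.\n",
    "`max_size` and `min_freq` limit how many words the vocabulary keeps and how often a word must occur to be kept; words not in the vocabulary are mapped to `<unk>`."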
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 1 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "TEXT.build_vocab(train_dataset) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('the', 80), ('to', 41), ('you', 38), ('i', 32), ('of', 30)]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT.vocab.freqs.most_common(5)        # freqs is a collections.Counter of token frequencies; show the 5 most common"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "collections.defaultdict"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(TEXT.vocab.stoi)"
   ]
  },
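  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vocabulary behaviour described in 3.2 can be mimicked in a few lines of plain Python. `build_vocab` below is a hypothetical helper, not `TEXT.build_vocab`; it assumes torchtext's default special tokens, with `<unk>` at index 0 and `<pad>` at index 1, and assumes `max_size` does not count those specials. Words are ordered by descending frequency, ties alphabetically.\n",
    "\n",
    "```python\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "def build_vocab(token_lists, max_size=None, min_freq=1):\n",
    "    # Count every token over all examples.\n",
    "    freqs = Counter(tok for toks in token_lists for tok in toks)\n",
    "    itos = ['<unk>', '<pad>']  # assumed specials: <unk> -> 0, <pad> -> 1\n",
    "    # Most frequent first; ties broken alphabetically.\n",
    "    for word, count in sorted(freqs.items(), key=lambda kv: (-kv[1], kv[0])):\n",
    "        if count < min_freq:\n",
    "            break  # all remaining words are rarer still\n",
    "        if max_size is not None and len(itos) - 2 >= max_size:\n",
    "            break  # vocabulary is full (specials not counted)\n",
    "        itos.append(word)\n",
    "    # Unknown words fall back to index 0, i.e. <unk>.\n",
    "    stoi = defaultdict(int, {word: i for i, word in enumerate(itos)})\n",
    "    return stoi, itos\n",
    "\n",
    "stoi, itos = build_vocab([['the', 'cat'], ['the', 'dog'], ['a', 'cat']], min_freq=2)\n",
    "print(itos)           # ['<unk>', '<pad>', 'cat', 'the']\n",
    "print(stoi['zebra'])  # 0\n",
    "```"
   ]
  },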
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
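    "<center>torchtext classes for reading data</center>\n",
    "\n",
    "|Class|Description|Use case|\n",
    "|-|-|-|\n",
    "|TabularDataset          |Reads CSV/TSV and JSON files |Problems where each text has one (or more) labels, e.g. text classification|\n",
    "|LanguageModelingDataset |Takes the path of a text file as input |Language modeling|\n",
    "|TranslationDataset      |Takes a path prefix plus one file extension per language, e.g. for an English file \"hoge.en\" and a French file \"hoge.fr\", pass path=\"hoge\", exts=(\".en\", \".fr\")|Machine translation|\n",
    "|SequenceTaggingDataset  |Takes the path of a file whose input and output sequences are separated by tabs|Sequence tagging|"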
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 创建迭代器\n",
    "上述步骤已经将数据读入到内存，在该步骤中将数据划分为批次，方便送入网络\n",
    "\n",
    "设置`sort_within_batch=False`，按照`sort_key`对每个批次内进行降序排列，用于对序列进行padding，将序列变为等长，这里的`sort_key = lambda token:len(token.comment_text)`语句表示根据单词长度来进行padding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext.data import Iterator, BucketIterator\n",
    "train_iter, val_iter = BucketIterator.splits((train_dataset, val_dataset),  # the datasets to iterate over\n",
    "        batch_sizes = (25,25),                     # (train_batch, val_batch)\n",
    "        device = -1,                               # -1 for CPU, otherwise the GPU index (starting at 0)\n",
    "        sort_key = lambda token:len(token.comment_text),   # group examples by text length\n",
    "        sort_within_batch = False,\n",
    "        repeat = False)                            # stop after one pass over the data, so the iterator can be wrapped below\n",
    "                                            \n",
    "test_iter = Iterator(test_dataset, batch_size=25, device=-1,\n",
    "                    sort=False, sort_within_batch=False, repeat=False) "
   ]
  },
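  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simplified picture of bucketing in plain Python: sort examples by length, cut them into batches so each batch holds sequences of similar length, shuffle the batch order, and right-pad every sequence to the longest one in its batch. This is a hypothetical sketch (`bucket_batches` and `pad_batch` are illustrative names); the real `BucketIterator` roughly does this on shuffled chunks of the data rather than sorting the whole dataset at once, and the padding index `1` is assumed from torchtext's default vocabulary.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "PAD_IDX = 1  # assumption: <pad> sits at index 1, as in torchtext's default vocab\n",
    "\n",
    "def bucket_batches(examples, batch_size, sort_key, shuffle=True, seed=0):\n",
    "    # Sort by length so neighbouring examples need little padding.\n",
    "    ordered = sorted(examples, key=sort_key)\n",
    "    batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]\n",
    "    if shuffle:\n",
    "        random.Random(seed).shuffle(batches)  # shuffle batch order, not contents\n",
    "    return batches\n",
    "\n",
    "def pad_batch(batch):\n",
    "    # Right-pad every sequence to the length of the longest one in the batch.\n",
    "    max_len = max(len(seq) for seq in batch)\n",
    "    return [seq + [PAD_IDX] * (max_len - len(seq)) for seq in batch]\n",
    "\n",
    "seqs = [[5, 6], [7, 8, 9, 10], [11], [12, 13, 14]]\n",
    "for batch in bucket_batches(seqs, batch_size=2, sort_key=len, shuffle=False):\n",
    "    print(pad_batch(batch))\n",
    "# [[11, 1], [5, 6]]\n",
    "# [[12, 13, 14, 1], [7, 8, 9, 10]]\n",
    "```\n",
    "\n",
    "Batching the same list in its original order would need four padding positions in total instead of two."
   ]
  },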
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[torchtext.data.batch.Batch of size 25]\n",
       "\t[.comment_text]:[torch.LongTensor of size 25x503]\n",
       "\t[.toxic]:[torch.LongTensor of size 25]\n",
       "\t[.severe_toxic]:[torch.LongTensor of size 25]\n",
       "\t[.obscene]:[torch.LongTensor of size 25]\n",
       "\t[.threat]:[torch.LongTensor of size 25]\n",
       "\t[.insult]:[torch.LongTensor of size 25]\n",
       "\t[.identity_hate]:[torch.LongTensor of size 25]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# read one batch\n",
    "sample_batch = next(iter(train_iter))\n",
    "sample_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['batch_size', 'dataset', 'train', 'fields', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_batch.__dict__.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 76,  90,   2,  ...,   1,   1,   1],\n",
       "        [200,  48, 198,  ...,   1,   1,   1],\n",
       "        [354,  50,  28,  ...,   1,   1,   1],\n",
       "        ...,\n",
       "        [553,  44, 671,  ...,   1,   1,   1],\n",
       "        [348,  72,   4,  ...,   1,   1,   1],\n",
       "        [196, 355, 245,  ...,   1,   1,   1]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_batch.comment_text         # text data: token indices, padded with the <pad> index 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "        0])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_batch.toxic               # label data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note: this kind of custom data type makes code hard to reuse; when the column names (headers) of the CSV change, the code has to change too"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
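    "## 5. Wrap the iterator\n",
    "The dataset has several labels, so the iterator returns one text tensor plus one tensor per label. For training, however, the labels are usually needed as a single tensor, so the batches produced by the iterator are wrapped to yield `(x, y)` pairs of inputs and labels."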
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchWrapper(object):\n",
    "    def __init__(self, data_iter, x_var, y_vars):\n",
    "        \"\"\"\n",
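    "        func: wrap a data iterator so that it yields batches as (x, y) pairs\n",
    "        data_iter: the data iterator\n",
    "        x_var: name of the attribute used as input x (a single string)\n",
    "        y_vars: names of the attributes used as labels y (a list), or None\n",
    "        \"\"\"\n",
    "        self.data_iter, self.x_var, self.y_vars = data_iter, x_var, y_vars\n",
    "    def __iter__(self):\n",
    "        for batch in self.data_iter:\n",
    "            x = getattr(batch, self.x_var)           # get the x_var attribute, here comment_text\n",
    "            \n",
    "            if self.y_vars is not None:              # concatenate the labels into one tensor\n",
    "                # (N,) => (N,1), then concatenate along dim 1 => (N, len(y_vars))\n",
    "                temp = [getattr(batch, attr).unsqueeze(1) for attr in self.y_vars]\n",
    "                y = torch.cat(temp, dim=1).float()   # BCEWithLogitsLoss expects float targets\n",
    "            else:\n",
    "                y = torch.zeros((1))                 # placeholder when there are no labels (test set)\n",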
    "            yield (x, y)\n",
    "            \n",
    "    def __len__(self):\n",
    "        return len(self.data_iter)"
   ]
  },
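  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The label stacking in `__iter__` can be checked without torch. In this hypothetical plain-Python sketch with made-up values, each label is a column with one entry per sample, and `zip(*columns)` turns the columns into one row per sample, which is what `unsqueeze(1)` followed by `torch.cat(..., dim=1)` does to the `(N,)` label tensors (`torch.stack(..., dim=1)` would give the same `(N, 6)` result in one call).\n",
    "\n",
    "```python\n",
    "y_vars = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']\n",
    "# Made-up labels for a batch of 3 samples: one list (column) per label.\n",
    "columns = {\n",
    "    'toxic':         [0, 1, 0],\n",
    "    'severe_toxic':  [0, 0, 0],\n",
    "    'obscene':       [0, 1, 0],\n",
    "    'threat':        [0, 0, 0],\n",
    "    'insult':        [0, 1, 0],\n",
    "    'identity_hate': [0, 0, 0],\n",
    "}\n",
    "# Row i gathers sample i's labels in y_vars order: shape (N, len(y_vars)).\n",
    "y = [list(row) for row in zip(*(columns[name] for name in y_vars))]\n",
    "print(y)  # [[0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 1, 0], [0, 0, 0, 0, 0, 0]]\n",
    "```"
   ]
  },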
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_var = 'comment_text'\n",
    "y_vars = [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n",
    "train_loader = BatchWrapper(train_iter, x_var, y_vars)\n",
    "val_loader   = BatchWrapper(val_iter, x_var, y_vars)\n",
    "test_loader  = BatchWrapper(test_iter, x_var, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "1. x tensor: torch.Size([25, 503])\n",
       "2. y tensor: torch.Size([25, 6])\n"
      ]
     }
    ],
    "source": [
     "x_batch, y_batch = next(iter(train_loader))\n",
     "print('1. x tensor:', x_batch.shape)\n",
     "print('2. y tensor:', y_batch.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Train the text classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 Define the text classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMBaseline(nn.Module):\n",
    "    def __init__(self, embed_size, hidden_size, num_linear=0):\n",
    "        \"\"\"\n",
    "        func: a simple LSTM model for text classification\n",
    "        embed_size: embedding dimension\n",
    "        hidden_size: size of the LSTM hidden state\n",
    "        num_linear: number of intermediate fully connected layers\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # the vocabulary size is len(TEXT.vocab)\n",
    "        self.embedding = nn.Embedding(len(TEXT.vocab), embed_size)      # (N,seq) => (N,seq,embed_size)\n",
    "        self.encoder = nn.LSTM(embed_size, hidden_size, num_layers=2, batch_first=True, dropout=0.1)\n",
    "        self.linear_layers = []                       # list of intermediate fully connected layers\n",
    "        # add the intermediate layers to the model\n",
    "        for _ in range(num_linear):\n",
    "            self.linear_layers.append(nn.Linear(hidden_size, hidden_size)) \n",
    "        self.linear_layers = nn.ModuleList(self.linear_layers)         # register them as submodules\n",
    "        # output layer\n",
    "        self.predictor = nn.Linear(hidden_size, 6)     # 6 labels, so 6 outputs\n",
    "   \n",
    "    def forward(self, seq):\n",
    "        embedded_out = self.embedding(seq)             # (N,seq) => (N,seq,embed_size)\n",
    "        lstm_out, lstm_hidden = self.encoder(embedded_out)     # lstm_out: (N,seq,hidden)\n",
    "        feature = lstm_out[:,-1,:]                     # output of the last time step (N,hidden); with right padding this may be a <pad> position\n",
    "        for layer in self.linear_layers:               # (N,hidden)\n",
    "            feature = layer(feature)                   # note: no activation between these layers\n",
    "        predicts = self.predictor(feature)             # (N,6) raw logits; the sigmoid is applied inside BCEWithLogitsLoss\n",
    "        return predicts"
   ]
  },
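  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`forward` reads `lstm_out[:,-1,:]`, the output at the last time step. Because the batches are right-padded, for every comment shorter than the longest one in its batch that step comes after a run of `<pad>` tokens. A common alternative is to read each sequence's output at its last real token. The plain-Python sketch below only computes those positions; `last_token_index` is a hypothetical helper, the indices are made up, and it assumes right padding with pad index `1`. In PyTorch the positions could then be used as `lstm_out[torch.arange(N), torch.tensor(last_idx)]`.\n",
    "\n",
    "```python\n",
    "PAD_IDX = 1  # assumption: torchtext's default <pad> index\n",
    "\n",
    "def last_token_index(batch, pad_idx=PAD_IDX):\n",
    "    # With right padding, the real tokens form a prefix of each row, so the\n",
    "    # last real token sits at (number of non-pad tokens - 1).\n",
    "    result = []\n",
    "    for seq in batch:\n",
    "        length = sum(1 for tok in seq if tok != pad_idx)\n",
    "        result.append(max(length - 1, 0))  # an all-pad row falls back to 0\n",
    "    return result\n",
    "\n",
    "batch = [[76, 90, 2, 1, 1], [200, 48, 198, 35, 1], [354, 50, 28, 7, 9]]\n",
    "print(last_token_index(batch))  # [2, 3, 4]\n",
    "```"
   ]
  },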
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 Train the text classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_size = 128\n",
    "hidden_size = 500\n",
    "lstm_model = LSTMBaseline(embedding_size, hidden_size)      # build the model\n",
    "criterion = nn.BCEWithLogitsLoss()                          # binary cross-entropy for each label (multi-label)\n",
    "# criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(lstm_model.parameters(), lr=1e-3) "
   ]
  },
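  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`nn.BCEWithLogitsLoss` applies the sigmoid itself, which is why `forward` returns raw scores. With its default `reduction='mean'` it averages, over all N x 6 entries, the per-entry loss -[y log σ(x) + (1 - y) log(1 - σ(x))] with σ(x) = 1 / (1 + e^(-x)). The plain-Python sketch below (`bce_with_logits` is a hypothetical name) computes that formula in the algebraically equivalent, overflow-free form max(x, 0) - x·y + log(1 + e^(-|x|)); it is meant for checking numbers by hand, not as a replacement for the PyTorch loss.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def bce_with_logits(logits, targets):\n",
    "    # Mean binary cross-entropy over every (sample, label) entry.\n",
    "    total, count = 0.0, 0\n",
    "    for row_x, row_y in zip(logits, targets):\n",
    "        for x, y in zip(row_x, row_y):\n",
    "            # Stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))].\n",
    "            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))\n",
    "            count += 1\n",
    "    return total / count\n",
    "\n",
    "# A logit of 0 means sigmoid(0) = 0.5, so each entry costs log(2) whatever the label.\n",
    "print(round(bce_with_logits([[0.0, 0.0]], [[1, 0]]), 4))  # 0.6931\n",
    "```"
   ]
  },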
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, val_loader, epochs=1):\n",
    "    \"\"\"\n",
    "    func: train the model\n",
    "    train_loader: wrapped training iterator yielding (x, y)\n",
    "    val_loader: wrapped validation iterator yielding (x, y)\n",
    "    epochs: number of training epochs\n",
    "    \"\"\"\n",
    "    pltLoss = []\n",
    "    GPU = lambda x:x.to(device)                        # use the GPU when available\n",
    "    model.to(device)\n",
    "    for epoch in range(1, epochs+1):\n",
    "        model.train()                                  # training mode\n",
    "        running_loss, n_train = 0.0, 0                 # reset at the start of every epoch\n",
    "        for x, y in train_loader:\n",
    "            x, y = map(GPU, [x, y])\n",
    "            preds = model(x)\n",
    "            loss = criterion(preds, y)\n",
    "            optimizer.zero_grad()                      # update the parameters\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # loss.item() is the mean over the batch, so weight it by the batch size\n",
    "            running_loss += loss.item() * x.shape[0]\n",
    "            n_train += x.shape[0]\n",
    "            \n",
    "        epoch_loss = running_loss / n_train            # mean loss per sample for this epoch\n",
    "        pltLoss.append(epoch_loss)\n",
    "        \n",
    "        # loss on the validation set\n",
    "        val_loss, n_val = 0.0, 0\n",
    "        model.eval()                                   # evaluation mode\n",
    "        with torch.no_grad():\n",
    "            for x_val, y_val in val_loader:\n",
    "                x_val, y_val = map(GPU, [x_val, y_val])   # move the data to the GPU\n",
    "                preds = model(x_val)\n",
    "                loss = criterion(preds, y_val)\n",
    "                val_loss += loss.item() * x_val.shape[0]\n",
    "                n_val += x_val.shape[0]\n",
    "        epoch_val_loss = val_loss / n_val\n",
    "        print('Epoch: {:<3} Train Loss: {:<8.4f} Val Loss: {:<7.4f}'.format(\n",
    "            epoch, epoch_loss, epoch_val_loss))\n",
    "    return model, pltLoss"
   ]
  },
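  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`loss.item()` returns the mean loss of one batch, so `loss.item() * x.shape[0]` is that batch's summed loss. A per-sample average therefore divides the accumulated sum by the number of samples; dividing it by the number of batches (`len(train_loader)`) instead inflates the figure by roughly the batch size. A worked example with made-up numbers:\n",
    "\n",
    "```python\n",
    "# Made-up per-batch mean losses and batch sizes (the last batch is smaller).\n",
    "batch_losses = [0.50, 0.20, 0.40]\n",
    "batch_sizes = [25, 25, 10]\n",
    "\n",
    "weighted_sum = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))  # 21.5\n",
    "per_batch = weighted_sum / len(batch_losses)  # divided by the number of batches\n",
    "per_sample = weighted_sum / sum(batch_sizes)  # mean loss per sample\n",
    "\n",
    "print(round(per_batch, 4), round(per_sample, 4))  # 7.1667 0.3583\n",
    "```"
   ]
  },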
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchtext\\data\\field.py:322: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
      "  return Variable(arr, volatile=not train)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1   Train Loss: 17.3588  Val Loss: 13.9564\n",
      "Epoch: 2   Train Loss: 13.8598  Val Loss: 8.7956 \n",
      "Epoch: 3   Train Loss: 8.4476   Val Loss: 5.5981 \n",
      "Epoch: 4   Train Loss: 4.9420   Val Loss: 5.1958 \n",
      "Epoch: 5   Train Loss: 4.2271   Val Loss: 5.2262 \n",
      "Epoch: 6   Train Loss: 3.8858   Val Loss: 5.4695 \n",
      "Epoch: 7   Train Loss: 3.9132   Val Loss: 5.6100 \n",
      "Epoch: 8   Train Loss: 3.9061   Val Loss: 5.6832 \n",
      "Epoch: 9   Train Loss: 3.8772   Val Loss: 5.7124 \n",
      "Epoch: 10  Train Loss: 3.8546   Val Loss: 5.6938 \n",
      "Wall time: 6.68 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# train the model\n",
    "epochs = 10\n",
    "model, plt_loss= train(lstm_model, train_loader, val_loader, epochs=epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x20b4c986eb8>]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl8VfWd//HX52YhJBC2hB0SQGQVUKMCVuuCW13r0mqr4rS/YX6daUdtB9d2tE5rtTrTdn7TjbEurZZOi+gw1p/irlWUBpQdBFnDloSdQIAkn/njXiBAgAC595vc834+HnmQe+6557y5D8g753u+9xxzd0REJLpioQOIiEhYKgIRkYhTEYiIRJyKQEQk4lQEIiIRpyIQEYk4FYGkDTNbbmZjQucQaWlUBCIiEaciEAnIzDJDZxBREUhaMrNWZvZTM1uT+PqpmbVKPFdgZi+Z2WYz22hm75lZLPHc3Wa22sy2mdkiM7vwMNtvbWb/amYrzGyLmf0lsew8Mys7aN19Q1Zm9qCZTTKzZ81sK3Cfme00s4711j/VzCrNLCvx+GtmtsDMNpnZq2ZWlKS3TSJKRSDp6n5gJDACGA6cCXw38dx3gDKgEOgC3Ae4mQ0Avgmc4e5tgUuA5YfZ/uPA6cBooCNwF1DXyGxXA5OA9sBjwDTgunrPfwWY5O57zOyaRL5rE3nfAyY2cj8ijaIikHT1VeAhdy939wrg+8Atief2AN2AInff4+7vefyiW7VAK2CwmWW5+3J3/+zgDSeOHr4G3O7uq9291t0/cPddjcw2zd1fdPc6d98J/B64KbFtA25MLAP4O+BH7r7A3WuAh4EROiqQpqQikHTVHVhR7/GKxDKI/xa+BJhqZkvN7B4Ad18C3AE8CJSb2R/MrDuHKgBygENKopFWHfR4EjAqsa9zASf+mz9AEfCzxDDWZmAjYECP49y3yCFUBJKu1hD/IbpX78Qy3H2bu3/H3fsCVwLf3nsuwN1/7+6fS7zWgUcb2HYlUA30a+C5KiB37wMzyyA+pFPfAZf8dffNwFTgS8SHhSb6/ssCrwL+zt3b1/tq7e4fHPUdEGkkFYGkq4nAd82s0MwKgH8GngUwsyvM7KTEMMxW4kNCtWY2wMwuSJxUrgZ2Jp47gLvXAU8C/2Zm3c0sw8xGJV73KZBjZpcnTvZ+l/hw09H8HriV+LmC39db/ivgXjMbksjezsxuOI73Q+SwVASSrn4AlAKzgTnAzMQygP7A68B24idqf+HubxP/gf0I8d/41wGdiZ+obcg/Jbb7V+LDNY8CMXffAvw98ASwmvgRQtlhtlHflESu9e4+a+9Cd38hse0/JGYZzQUua8T2RBrNdGMaEZFo0xGBiEjEqQhERCIuaUVgZk+aWbmZza23bISZfWhmn5hZqZmdmaz9i4hI4yTziOBp4NKDlv0Y+L67jyA+i+PHSdy/iIg0QtIueOXu75pZ8cGLgfzE9+1IzOs+moKCAi8uPnhTIiJyJDNmzKh094M/x3KIVF/58A7gVTN7nPjRyOjDrWhm44BxAL1796a0tDQ1CUVE0oSZrTj6Wqk/WfwN4E537wXcCfzmcCu6+wR3L3H3ksLCoxaaiIgcp1QXwVhgcuL7PxG/IqSIiASU6iJYA3w+8f0FwOIU719ERA6StHMEZjYROA8oSNyo4wHgb4lfSTGT+LVcxiVr/yIi0jjJnDV002GeOj1Z+xQRkWOnTxaLiEScikBEJOLSughmrNjIL98+3ptIiYhEQ1oXwUuz1/LoKwv5aOmG0FFERJqttC6C8ZcMoHfHXO56fjY7dx9yoykRESHNiyA3O5NHrxvGig07eHzqotBxRESapbQuAoBR/Tpx88jePPn+Mmas2Bg6johIs5P2RQBwz2WD6N6uNeMnzaZ6j4aIRETqi0QRtGmVySPXncLSiip+8vqnoeOIiDQrkSgCgHP6F3LjGb34
z3eX8smqzaHjiIg0G5EpAoD7Lh9El/wcxv9pFrtqNEQkIgIRK4L8nCwevvYUFpdv59/f0IVPRUQgYkUAcP6Azlx/ek9+9c5S5pRtCR1HRCS4yBUBwPcuH0ynvGzGT5rF7pq60HFERIKKZBG0y83i4S+ewsJ12/j5W0tCxxERCSqSRQAwZnAXrhnRnZ+/tYT5a7aGjiMiEkxkiwDggSuH0D43PkS0p1ZDRCISTZEugg552fzgmiHMW7OVX7+jy1WLSDRFuggALh3ajcuHdePf31jCp+u3hY4jIpJykS8CgIeuGkKbnEzG/2kWNRoiEpGISVoRmNmTZlZuZnMPWv4tM1tkZvPM7MfJ2v+x6NSmFd+/agizyrbwxF+WhY4jIpJSyTwieBq4tP4CMzsfuBoY5u5DgMeTuP9jcsWwblwypAv/9tqnLCnfHjqOiEjKJK0I3P1d4OAbAHwDeMTddyXWKU/W/o+VmfEv1wwlNzuDuybNorbOQ0cSEUmJVJ8jOBk4x8w+MrN3zOyMw61oZuPMrNTMSisqKlISrnPbHB64cjAzV27mqfc1RCQi0ZDqIsgEOgAjgfHAH83MGlrR3Se4e4m7lxQWFqYs4DUjenDhwM48PnURyyurUrZfEZFQUl0EZcBkj5sO1AEFKc5wRGbGD794ClkZMe56fjZ1GiISkTSX6iJ4EbgAwMxOBrKByhRnOKqu7XL43hWDmb5sI7/7cEXoOCIiSZXM6aMTgWnAADMrM7OvA08CfRNTSv8AjHX3Zvkr9w2n9+Tckwt59JWFrNq4I3QcEZGkSeasoZvcvZu7Z7l7T3f/jbvvdveb3X2ou5/m7m8ma/8nysx45NpTiJlx9/OzaaZ9JSJywvTJ4iPo3r41931hEB98toHfT18ZOo6ISFKoCI7ipjN7cfZJnXj4zwso26QhIhFJPyqCo4gPEQ3DgXsnz9EQkYikHRVBI/TqmMs9lw3kvcWV/LF0Veg4IiJNSkXQSDefVcRZfTryg5cWsHbLztBxRESajIqgkWIx49HrhrGnro77NEQkImlERXAMigvyGH/JQN5aVMHkmatDxxERaRIqgmN02+hiTi/qwPf/Zx7lW6tDxxEROWEqgmOUETN+fP0wdtXUcf+LczVEJCItnorgOPQrbMO3LzqZ1+avZ8qsNaHjiIicEBXBcfo/5/RleK/2PDhlHhXbdoWOIyJy3FQExykjZjx+/TCqdtXywJS5R3+BiEgzpSI4Af27tOX2Mf15ec46Xp6zNnQcEZHjoiI4QePO7cvQHvl878W5bKzaHTqOiMgxUxGcoKyMGI9dP5yt1Xt4cMq80HFERI6ZiqAJDOqWzzfP78+UWWuYOm9d6DgiIsdERdBE/v78fgzqls/9L85l8w4NEYlIy6EiaCLxIaJhbKrazUMvzQ8dR0Sk0VQETWhoj3Z847x+TJ65mjcXrg8dR0SkUVQETeybF5zEyV3acN/kuWyt3hM6jojIUSWtCMzsSTMrN7NDPm1lZv9kZm5mBcnafyitMjN47PrhlG+r5ocvLQgdR0TkqJJ5RPA0cOnBC82sF3ARkLZ3gx/eqz3jzu3Hf5Wu4t1PK0LHERE5oqQVgbu/C2xs4KmfAHcBaX3ZzjvG9KdfYR73Tp7D9l01oeOIiBxWSs8RmNlVwGp3n9WIdceZWamZlVZUtLzfqnOyMvjx9cNZs2UnP3pZQ0Qi0nylrAjMLBe4H/jnxqzv7hPcvcTdSwoLC5MbLklOL+rA18/uw3MfreSDJZWh44iINCiVRwT9gD7ALDNbDvQEZppZ1xRmSLnvXDyA4k65PPTSfN3ERkSapZQVgbvPcffO7l7s7sVAGXCau6f1NRlaZ2fwfz/fj4XrtjF9WUOnTEREwkrm9NGJwDRggJmVmdnXk7Wv5u7qET1on5vF0x8sDx1FROQQmcnasLvfdJTni5O17+amdXYGXz6jF0+8t4w1m3fSvX3r0JFERPbRJ4tT5JaRRbg7z364InQUEZED
qAhSpGeHXMYM6sLE6Sup3lMbOo6IyD4qghS67exiNu3Yw5RZa0JHERHZR0WQQqP6dmJAl7Y888FyTSUVkWZDRZBCZsato4uYt2YrM1ZsCh1HRARQEaTcF0/tQX5OJk9pKqmINBMqghTLzc7ky2f04pW561i3pTp0HBERFUEIt4wsps6d5z7SVFIRCU9FEEDvTrlcOLAzv/9IU0lFJDwVQSC3je7Dhqrd/Hn22tBRRCTiVASBnH1SJ07q3IZnpmkqqYiEpSIIxMwYO6qI2WVb+HjV5tBxRCTCVAQBXXtaT9q2yuTp95eHjiIiEaYiCCivVSY3lPTi5TlrKd+qqaQiEoaKILBbRxVR685zH60MHUVEIkpFEFhxQR7nnVzIcx+tZHdNXeg4IhJBKoJm4Laz+1C5fRcvz9FUUhFJPRVBM3DOSQX0LcjTrSxFJAgVQTMQixm3jirik1Wb+URTSUUkxVQEzcR1p/ckLzuDZ3RUICIppiJoJtrmZHFDSS9emr2Gim27QscRkQhJWhGY2ZNmVm5mc+ste8zMFprZbDN7wczaJ2v/LdGto4rYU+tMnK6ppCKSOsk8IngauPSgZa8BQ919GPApcG8S99/i9C1sw7knF/Lshys0lVREUiZpReDu7wIbD1o21d1rEg8/BHoma/8t1W2jiyjftotX5q0LHUVEIiLkOYKvAf//cE+a2TgzKzWz0oqKihTGCuu8kztT1ClXJ41FJGWCFIGZ3Q/UAM8dbh13n+DuJe5eUlhYmLpwgcWnkhYzY8Um5pRtCR1HRCIg5UVgZmOBK4Cvui7E36AbSnqSm52hD5iJSEqktAjM7FLgbuAqd9+Ryn23JPk5WVx3Wk/+Z/YaNmzXVFIRSa5kTh+dCEwDBphZmZl9HfgPoC3wmpl9Yma/Stb+W7qxo4vYXVPHH/66KnQUEUlzmcnasLvf1MDi3yRrf+nmpM5t+dxJBfxu2grGnduXrAx99k9EkkM/XZqxsaOLWbe1mqnz1oeOIiJpTEXQjF0wsDO9OrbWVFIRSSoVQTOWETNuHVnM9OUbmbdGU0lFJDlUBM3cl0p60TpLVyUVkeRRETRz7XKzuObUHvz3J2vYVLU7dBwRSUMqghbgttHF7NJUUhFJEhVBCzCga1tG9e3E76Ytp6ZWVyUVkaalImghxo4uZs2Wal5foKmkItK0GlUEZna7meVb3G/MbKaZXZzscLLfmEGd6dG+ta4/JCJNrrFHBF9z963AxUAh8DfAI0lLJYfIzIhxy6giPly6kYXrtoaOIyJppLFFYIk/vwA85e6z6i2TFPlySS9aZcY0lVREmlRji2CGmU0lXgSvmllbQGctU6xDXjbXjOjBCx+vZvMOTSUVkabR2CL4OnAPcEbi8tFZxIeHJMXGji6mek8dfyzVVFIRaRqNLYJRwCJ332xmNwPfBXTNgwAGd8/nzD4d+e20FdTW6b4+InLiGlsEvwR2mNlw4C5gBfDbpKWSI7ptdDFlm3byhqaSikgTaGwR1CRuK3k18DN3/xnxG8xIABcP7kK3djk8M2156CgikgYaWwTbzOxe4Bbgz2aWQfw8gQSQmRHj5pFFvL9kA4vXbwsdR0RauMYWwZeBXcQ/T7AO6AE8lrRUclQ3ndmb7MyYPmAmIiesUUWQ+OH/HNDOzK4Aqt1d5wgC6piXzVXDuzN55mq27NwTOo6ItGCNvcTEl4DpwA3Al4CPzOz6ZAaTo7ttdDE799TyJ00lFZET0NihofuJf4ZgrLvfCpwJfO9ILzCzJ82s3Mzm1lvW0cxeM7PFiT87HH90GdqjHSVFHTSVVEROSGOLIObu5fUeb2jEa58GLj1o2T3AG+7eH3gj8VhOwNjRxazcuIO3F5UffWURkQY0tgheMbNXzew2M7sN+DPw8pFe4O7vAhsPWnw18Ezi+2eAa44hqzTg0qFd6ZLfSieNReS4NfZk8XhgAjAMGA5McPe7j2N/Xdx9bWKba4HOh1vRzMaZWamZlVZU
VBzHrqIhKyPGzWcV8d7iSpaUbw8dR0RaoEbfmMbdn3f3b7v7ne7+QjJDJfY3wd1L3L2ksLAw2btr0W46qzfZGTF+O2156Cgi0gIdsQjMbJuZbW3ga5uZHc9F8debWbfEtrsBGthuAgVtWnHFsG48P6OMbdWaSioix+aIReDubd09v4Gvtu6efxz7mwKMTXw/Fvjv49iGNGDs6GKqdtcyaUZZ6Cgi0sIk7Z7FZjYRmAYMMLMyM/s68buaXWRmi4GL0F3OmszwXu05tXd7nvlgOXWaSioixyAzWRt295sO89SFydpn1N02upjb//AJ7yyu4PwBhz0PLyJygKQdEUjqXTa0G4VtW+lWliJyTFQEaSQ7M8ZXz+rN24sqWFZZFTqOiLQQKoI085WzepOVYToqEJFGUxGkmc5tc/jCKd2YNKOM7btqQscRkRZARZCGxo4uZvuuGibP1FRSETk6FUEaOrVXe4b3bMfTmkoqIo2gIkhDZsbY0cUsrajiL0sqQ8cRkWZORZCmLh/WjYI22TppLCJHpSJIU60yM7jpzN68uaicFRs0lVREDk9FkMa+elYRGWb8dtqK0FFEpBlTEaSxru1yuHRoV/5YuooqTSUVkcNQEaS520YXs626hhc+Xh06iog0UyqCNHd6UQeG9sjnmQ+W466ppCJyKBVBmjMzxo4qZnH5dj74bEPoOCLSDKkIIuDK4d3pmJetG9yLSINUBBGQk5XBjWf04o0F61m1cUfoOCLSzKgIIuLmkUWYGb94+7PQUUSkmVERRET39q0ZO6qYidNXMk3nCkSkHhVBhIy/ZABFnXK5+/nZ7NitzxWISJyKIEJaZ2fwyLXDWLlxB4+/+mnoOCLSTAQpAjO708zmmdlcM5toZjkhckTRqH6duGVkEU99sIzS5RtDxxGRZiDlRWBmPYB/BErcfSiQAdyY6hxRds9lA+nerjV3TZpN9Z7a0HFEJLBQQ0OZQGszywRygTWBckRSXqtMHr1uGEsrq/jJ6xoiEom6lBeBu68GHgdWAmuBLe4+9eD1zGycmZWaWWlFRUWqY6a9z/Uv4MYzevGf7y7lk1WbQ8cRkYBCDA11AK4G+gDdgTwzu/ng9dx9gruXuHtJYWFhqmNGwn2XD6JLfg7j/zSLXTUaIhKJqhBDQ2OAZe5e4e57gMnA6AA5Ii8/J4uHv3gKi8u38x9vLgkdR0QCCVEEK4GRZpZrZgZcCCwIkEOA8wd25trTevCLtz9j7uotoeOISAAhzhF8BEwCZgJzEhkmpDqH7PfPVwymY1424yfNZndNXeg4IpJiQWYNufsD7j7Q3Ye6+y3uvitEDolrn5vND68ZyoK1W/nVO7oWkUjU6JPFAsDFQ7py1fDu/L83F7Nw3dbQcUQkhVQEss+DVw0hPyeL8X+aTU2thohEokJFIPt0zMvmoauHMmf1Fv7zvWWh44hIiqgI5ACXD+vGZUO78pPXP2VJ+fbQcUQkBVQEcoiHrh5KbnYG4yfNorZON7wXSXcqAjlEYdtWPHjlED5euZmn3tcQkUi6UxFIg64e0Z0xgzrz2KuLWFZZFTqOiCSRikAaZGb88IunkJ0Z4+7nZ1OnISKRtKUikMPqkp/D964YzPRlG/ndhytCxxGRJFERyBHdcHpPzj25kEdfWciqjTtCxxGRJFARyBGZGT+69hRiZtwzeTbuGiISSTcqAjmqHu1bc+8XBvL+kg1MnL4qdBwRaWIqAmmUr5zZm9H9OvHwywtYvXln6Dgi0oRUBNIoZsaj1w2jzp37Js/REJFIGlERSKP16pjL3ZcO5J1PK5g0oyx0HBFpIioCOSa3jCzizOKO/MtL81m/tTp0HBFpAioCOSaxmPHo9cPYVVPH/S9oiEgkHagI5Jj1Kchj/CUDeH1BOVNmrQkdR0ROkIpAjsvfnN2HU3u354Ep86jYpjuNirRkKgI5Lhkx47Hrh7Fjdy0PTJkbOo6InIAgRWBm7c1skpktNLMFZjYqRA45MSd1
bssdY/rz8px1vDxnbeg4InKcQh0R/Ax4xd0HAsOBBYFyyAkad05fTunRju+9OJeNVbtDxxGR45DyIjCzfOBc4DcA7r7b3TenOoc0jcyMGI/dMIyt1Xt4cMq80HFE5DiEOCLoC1QAT5nZx2b2hJnlBcghTWRg13y+eX5/psxaw9R560LHEZFjFKIIMoHTgF+6+6lAFXDPwSuZ2TgzKzWz0oqKilRnlGP09+f3Y1C3fL774ly27NgTOo6IHIMQRVAGlLn7R4nHk4gXwwHcfYK7l7h7SWFhYUoDyrHLyojx2PXD2FC1m4demh86jogcg5QXgbuvA1aZ2YDEogsB/eRIA0N7tOMbn+/H8zPLeGtReeg4ItJIoWYNfQt4zsxmAyOAhwPlkCb2rQtPon/nNtw3eQ5bqzVEJNISBCkCd/8kMewzzN2vcfdNIXJI02uVmcFjNwxn/dZqfvSyZgWLtAT6ZLE0uRG92vO35/Rl4vRV/GVxZeg4InIUKgJJijsvOpm+BXnc/fxsqnbVhI4jIkegIpCkyMnK4MfXD2PNlp08+srC0HFE5AhUBJI0JcUduW10Mb+dtoIPl24IHUdEDkNFIEk1/pIB9O6Yy93Pz2bn7trQcUSkASoCSarc7EwevW4YKzbs4PGpi0LHEZEGqAgk6Ub168TNI3vz5PvLmLFiY+g4InIQFYGkxD2XDaJ7u9aMnzSb6j0aIhJpTlQEkhJtWmXyyHWnsLSiip++vjh0HBGpJzN0AImOc/oXcuMZvZjw7mfEDC4Z0pVTerQjFrPQ0UQiTUUgKXXf5YNYt7WaX7+7lF+8/Rld8lsxZlAXLh7SlZF9O9IqMyN0RJHIMXcPneGoSkpKvLS0NHQMaUKbqnbz1qJyXpu/nnc+rWDH7lratMrk8wMKuXhwF84b0Jl2rbNCxxRp0cxshruXHHU9FYGEVr2nlg8+q+S1+et5bX45ldt3kRkzRvbtxEWDuzBmcBd6tG8dOqZIi6MikBaprs75eNXmRCms47OKKgCGdM/nosFduHhwVwZ1a4uZziuIHI2KQNLCZxXbE6WwnpkrN+EOPdq3TpRCF87o05GsDE1+E2mIikDSTsW2Xby5MF4K7y2uZFdNHe1aZ3HBwM5cNLgL555cSJtWmv8gspeKQNLajt01vPtp/LzCmwvXs2nHHrIzYow+qRMXD+7KmEGd6ZyfEzqmSFAqAomMmto6ZqzYxNTEENLKjTuA+A1y9g4hndS5jc4rSOSoCCSS3J1P12/ntfnrmDp/PbPLtgDQpyCPiwZ34aLBXTitdwcy9CE2iQAVgQiwbks1ry2IHylM+6ySPbVOp7xsLhjYmQsHdWZg13x6dmhNpk44SxpSEYgcZFv1Ht5eVMFr89fz1qJytlXHb6GZGTN6d8qlb0EefQryKE782begDV3yW2lISVqsxhZBsCkWZpYBlAKr3f2KUDkkOtrmZHHl8O5cObw7u2vqmLN6M59VVLGssorllfE/985G2is3O4PiTvFi2PdVmEffgjza52YH/NuINJ2Qc+1uBxYA+QEzSERlZ8Y4vagjpxd1PGB5XZ2zdms1yyqqWFa5nWWVO1hWuZ15a7bwyrx11NbtP4LukJu17wgifjTRJvE4l9xsTWOVliPIv1Yz6wlcDvwQ+HaIDCINicWMHu1b06N9az7Xv+CA53bX1LFq0w6WVVSxfEMVSyurWFZRxQdLNjB55uoD1u3WLid+JFGYt2/IqU9BHr065uoDcNLshPq15afAXUDbw61gZuOAcQC9e/dOUSyRw8vOjNGvsA39Ctsc8tyO3TUsr9zBssr4kcTSxFDTy3PWsnnHnn3rZcSMXh1aJ4qhDX0K8+iWn0OdO3Xu1NZBrTt1dU5tne///oBl7FtWW3fg8zUHrEtimwe/fu9rOGDZ3tOFZmCJPwEMiy+z+PeJhYl19i2p97r9y7B6rz9gvf3b3LeXxDoxMzJiRsyMzAxLPIYMM2IxIzMW/zOjwfVs33oZMciI
xRLr7d/uvq996+1/vHd7GbH9r9n72Ix9z8e/SCzf+7oD148ZLeb8UsqLwMyuAMrdfYaZnXe49dx9AjAB4ieLUxRP5LjkZmcyuHs+g7sfOtK5qWo3yzZUJYab9n99uHQjO5vwbm1m+39YZtT7YbT3B13skB90+5/f+wPMiU/B3csdHE/8uXeZx79vYNne9fe91vc/n3jJAdusv4zEsr3lVefES6pe6bVEDZXE4UolZkZsb2lZfL0fXTuMM/t0PPqOTkCII4KzgavM7AtADpBvZs+6+80BsogkXYe8bDrkZXNa7w4HLHd31m/dxfqt1Q3+JnroD+36v8nSwLKW8dvniah/JHTIkc4BR0kkvq+LH2XV+b6Cqan3/SHbqzvwyMzrlVLd3uf8oMd18QKrTRzV1dUrMa+//pGeq7ffg5fntUr+PTpSXgTufi9wL0DiiOCfVAISRWZG13Y5dG2nS2E0VixmxDCydP+iJqWzViIiERd0jpu7vw28HTKDiEjU6YhARCTiVAQiIhGnIhARiTgVgYhIxKkIREQiTkUgIhJxLeJ+BGZWAaw4zpcXAJVNGKel0/uxn96LA+n9OFA6vB9F7l54tJVaRBGcCDMrbcyNGaJC78d+ei8OpPfjQFF6PzQ0JCIScSoCEZGIi0IRTAgdoJnR+7Gf3osD6f04UGTej7Q/RyAiIkcWhSMCERE5AhWBiEjEpXURmNmlZrbIzJaY2T2h84RiZr3M7C0zW2Bm88zs9tCZmgMzyzCzj83spdBZQjOz9mY2ycwWJv6djAqdKRQzuzPx/2SumU00s7S/c1DaFoGZZQA/By4DBgM3mdngsKmCqQG+4+6DgJHAP0T4vajvdmBB6BDNxM+AV9x9IDCciL4vZtYD+EegxN2HAhnAjWFTJV/aFgFwJrDE3Ze6+27gD8DVgTMF4e5r3X1m4vttxP+T9wibKiwz6wlcDjwROktoZpYPnAv8BsDdd7v75rCpgsoEWptZJpALrAmcJ+nSuQh6AKvqPS4j4j/8AMysGDgV+ChskuB+CtwF1IUO0gz0BSqApxJDZU+YWV7oUCG4+2rgcWAlsBbY4u5Tw6ZKvnQuAmtgWaTnyppZG+B54A533xo6TyhmdgVQ7u4zQmdpJjKB04D4+DRjAAADFklEQVRfuvupQBUQyXNqZtaB+MhBH6A7kGdmN4dNlXzpXARlQK96j3sSgUO8wzGzLOIl8Jy7Tw6dJ7CzgavMbDnxIcMLzOzZsJGCKgPK3H3vUeIk4sUQRWOAZe5e4e57gMnA6MCZki6di+CvQH8z62Nm2cRP+EwJnCkIMzPi478L3P3fQucJzd3vdfee7l5M/N/Fm+6e9r/1HY67rwNWmdmAxKILgfkBI4W0EhhpZrmJ/zcXEoET55mhAySLu9eY2TeBV4mf+X/S3ecFjhXK2cAtwBwz+ySx7D53fzlgJmlevgU8l/ilaSnwN4HzBOHuH5nZJGAm8dl2HxOBS03oEhMiIhGXzkNDIiLSCCoCEZGIUxGIiEScikBEJOJUBCIiEaciEEkCMztPVzWVlkJFICIScSoCiTQzu9nMppvZJ2b268Q9Crab2b+a2Uwze8PMChPrjjCzD81stpm9kLguDWZ2kpm9bmazEq/pl9h8m3rX+H8u8UlVzOwRM5uf2M7jgf7qIvuoCCSyzGwQ8GXgbHcfAdQCXwXygJnufhrwDvBA4iW/Be5292HAnHrLnwN+7u7DiV+XZm1i+anAHcTvh9EXONvMOgJfBIYktvOD5P4tRY5ORSBRdiFwOvDXxKU3LiT+A7sO+K/EOs8CnzOzdkB7d38nsfwZ4Fwzawv0cPcXANy92t13JNaZ7u5l7l4HfAIUA1uBauAJM7sW2LuuSDAqAokyA55x9xGJrwHu/mAD6x3pOiwNXe58r131vq8FMt29hvhNk54HrgFeOcbMIk1ORSBR9gZwvZl1BjCzjmZWRPz/xfWJdb4C/MXdtwCbzOycxPJbgHcS93UoM7Nr
EttoZWa5h9th4p4Q7RIX/LsDGJGMv5jIsUjbq4+KHI27zzez7wJTzSwG7AH+gfiNWYaY2QxgC/HzCABjgV8lftDXv0LnLcCvzeyhxDZuOMJu2wL/nbghugF3NvFfS+SY6eqjIgcxs+3u3iZ0DpFU0dCQiEjE6YhARCTidEQgIhJxKgIRkYhTEYiIRJyKQEQk4lQEIiIR97/Jxv0Q1fH1awAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x20b4b012e80>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.title('loss curve')\n",
    "plt.xlabel('epochs')\n",
    "plt.ylabel('loss')\n",
    "plt.plot(plt_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
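    "## 7. Evaluating model performance\n",
    "Here the trained model is run on the test set to produce per-label probabilities. The model above effectively performs six independent binary classifications, and the network outputs unnormalized scores (logits). To turn these into probabilities, apply the sigmoid function to each score separately: each class behaves like a single output neuron. Because the labels are not mutually exclusive (one comment can be both `toxic` and `obscene`), softmax, which forces the scores to sum to 1, is not appropriate here."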
   ]
  },
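  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of why sigmoid rather than softmax suits this multi-label task. The logits are made up for a single comment, in the label order `toxic, severe_toxic, obscene, threat, insult, identity_hate`, and the `stable_sigmoid` and `softmax` helpers are illustrative, not part of torchtext or of the model above. `stable_sigmoid` computes the same value as `1 / (1 + np.exp(-x))` but only ever exponentiates non-positive numbers, so it avoids overflow for very negative logits.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def stable_sigmoid(x):\n",
    "    # Illustrative helper: elementwise sigmoid that never calls np.exp on a large positive number\n",
    "    x = np.asarray(x, dtype=np.float64)\n",
    "    out = np.empty_like(x)\n",
    "    pos = x >= 0\n",
    "    out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))   # x >= 0, so exp(-x) <= 1\n",
    "    exp_x = np.exp(x[~pos])                    # x < 0, so exp(x) < 1\n",
    "    out[~pos] = exp_x / (1.0 + exp_x)\n",
    "    return out\n",
    "\n",
    "def softmax(x):\n",
    "    # Subtract the max before exponentiating for stability; the outputs sum to 1\n",
    "    z = np.exp(x - np.max(x))\n",
    "    return z / z.sum()\n",
    "\n",
    "# Made-up logits for one comment that is toxic, obscene and insulting at the same time\n",
    "logits = np.array([3.0, -2.0, 2.5, -4.0, 2.0, -3.0])\n",
    "\n",
    "sig = stable_sigmoid(logits)   # independent per-label probabilities\n",
    "soft = softmax(logits)         # a single distribution over mutually exclusive classes\n",
    "\n",
    "print(np.round(sig, 3))\n",
    "print(sig.sum(), soft.sum())\n",
    "```\n",
    "\n",
    "Several sigmoid outputs can be close to 1 at once and they need not sum to 1, whereas the softmax values always do. In PyTorch, `torch.sigmoid(preds)` applies the same function directly to the output tensor."
   ]
  },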
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_preds = []\n",
    "model.eval()\n",
    "model.cpu()\n",
    "with torch.no_grad():\n",
    "    for x, y in test_loader:\n",
    "        preds = model(x)\n",
    "        preds = preds.data.numpy()                  # convert the output tensor to a NumPy array\n",
    "        probs = 1 / (1 + np.exp(-preds))            # apply the sigmoid function elementwise\n",
    "        test_preds.append(probs)\n",
    "test_preds = np.vstack(test_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(test_csv)\n",
    "for i, col in enumerate([\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]):\n",
    "    df[col] = test_preds[:, i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment_text</th>\n",
       "      <th>toxic</th>\n",
       "      <th>severe_toxic</th>\n",
       "      <th>obscene</th>\n",
       "      <th>threat</th>\n",
       "      <th>insult</th>\n",
       "      <th>identity_hate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>00001cee341fdb12</td>\n",
       "      <td>Yo bitch Ja Rule is more succesful then you'll...</td>\n",
       "      <td>0.151023</td>\n",
       "      <td>0.020645</td>\n",
       "      <td>0.015647</td>\n",
       "      <td>0.000875</td>\n",
       "      <td>0.014704</td>\n",
       "      <td>0.001049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0000247867823ef7</td>\n",
       "      <td>== From RfC == \\n\\n The title is fine as it is...</td>\n",
       "      <td>0.151023</td>\n",
       "      <td>0.020645</td>\n",
       "      <td>0.015647</td>\n",
       "      <td>0.000875</td>\n",
       "      <td>0.014704</td>\n",
       "      <td>0.001049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>00013b17ad220c46</td>\n",
       "      <td>\" \\n\\n == Sources == \\n\\n * Zawe Ashton on Lap...</td>\n",
       "      <td>0.151023</td>\n",
       "      <td>0.020645</td>\n",
       "      <td>0.015647</td>\n",
       "      <td>0.000875</td>\n",
       "      <td>0.014704</td>\n",
       "      <td>0.001049</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 id                                       comment_text  \\\n",
       "0  00001cee341fdb12  Yo bitch Ja Rule is more succesful then you'll...   \n",
       "1  0000247867823ef7  == From RfC == \\n\\n The title is fine as it is...   \n",
       "2  00013b17ad220c46  \" \\n\\n == Sources == \\n\\n * Zawe Ashton on Lap...   \n",
       "\n",
       "      toxic  severe_toxic   obscene    threat    insult  identity_hate  \n",
       "0  0.151023      0.020645  0.015647  0.000875  0.014704       0.001049  \n",
       "1  0.151023      0.020645  0.015647  0.000875  0.014704       0.001049  \n",
       "2  0.151023      0.020645  0.015647  0.000875  0.014704       0.001049  "
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
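    "### Summary\n",
    "\n",
    "The steps for processing data with torchtext are:\n",
    "1. Define `Field` objects that declare how each column is preprocessed\n",
    "2. Use a dataset class such as `TabularDataset` to process the data according to the `Field` definitions\n",
    "3. Use a `BucketIterator` or `Iterator` to split the dataset into batches (at this point the data can be fed into the network)\n",
    "4. Optionally wrap the iterator again; here, for example, the separate label columns are merged into a single label vector, yielding `(x, y)` pairs of data and labels\n",
    "\n",
    "The spaCy NLP library provides tokenization, named-entity recognition, and dependency parsing, and it is fast."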
   ]
  }
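  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of step 4, assuming `SimpleNamespace` objects as stand-ins for torchtext batches (one attribute per CSV column, each label column a 1-D array) and NumPy's `np.stack` in place of `torch.cat`. The `LabelMergingWrapper` class and the fake token ids are hypothetical; the wrapper used earlier in this notebook may differ in detail.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "import numpy as np\n",
    "\n",
    "LABELS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']\n",
    "\n",
    "class LabelMergingWrapper:\n",
    "    # Hypothetical wrapper: turns batches with one attribute per column into (x, y) pairs\n",
    "    def __init__(self, batches, x_var, y_vars):\n",
    "        self.batches, self.x_var, self.y_vars = batches, x_var, y_vars\n",
    "\n",
    "    def __iter__(self):\n",
    "        for batch in self.batches:\n",
    "            x = getattr(batch, self.x_var)\n",
    "            # Stack the 1-D label columns side by side -> shape (batch_size, len(y_vars))\n",
    "            y = np.stack([getattr(batch, name) for name in self.y_vars], axis=1).astype(np.float32)\n",
    "            yield x, y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.batches)\n",
    "\n",
    "# Two fake batches of 2 comments each, with made-up token ids under comment_text\n",
    "fake_batches = [\n",
    "    SimpleNamespace(comment_text=np.array([[4, 7], [9, 1]]),\n",
    "                    **{name: np.array([1, 0]) for name in LABELS}),\n",
    "    SimpleNamespace(comment_text=np.array([[3, 5], [2, 8]]),\n",
    "                    **{name: np.array([0, 0]) for name in LABELS}),\n",
    "]\n",
    "\n",
    "loader = LabelMergingWrapper(fake_batches, 'comment_text', LABELS)\n",
    "for x, y in loader:\n",
    "    print(x.shape, y.shape)\n",
    "```\n",
    "\n",
    "With real torchtext batches of PyTorch tensors, the label line could instead be `torch.cat([getattr(batch, name).unsqueeze(1) for name in self.y_vars], dim=1).float()`, which yields the same `(batch_size, 6)` layout."
   ]
  }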
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
